# ==============================================================================
# Snakefile — 10X snMultiome Analysis Pipeline (Gastric WNT / RZ vs RZK)
# ==============================================================================
# Load pipeline parameters into the global `config` dict.
# NOTE(review): this relative path is resolved from the directory Snakemake is
# launched in, whereas every rule in this Snakefile passes the absolute path
# /pipeline/config.yaml to its R script. The two files are only guaranteed to
# be the same when the workflow is run from /pipeline — confirm this layout.
configfile: "config.yaml"

# Root directory for every pipeline artefact. Converted to str (matching what
# the f-strings below would do) and stripped of any trailing "/" so targets
# built as f"{OUTPUT}/..." never contain a double slash.
OUTPUT = str(config["output_dir"]).rstrip("/")

# Default target. Requesting the final artefact of every stage makes Snakemake
# schedule the whole pipeline. Targets are grouped by stage; Snakemake
# flattens the nested lists, so the resulting file list (and its order) is the
# same as an explicit one-path-per-line listing.
rule all:
    input:
        # Stages 01-03: one intermediate RDS object per genotype.
        [f"{OUTPUT}/01_qc/{geno}_qc.rds" for geno in ("wild_type", "mutant")],
        [f"{OUTPUT}/02_peaks/{geno}_peaks.rds" for geno in ("wild_type", "mutant")],
        [f"{OUTPUT}/03_normalised/{geno}_normalised.rds" for geno in ("wild_type", "mutant")],
        # Stage 04: single integrated object for both genotypes.
        f"{OUTPUT}/04_integrated/RZ_RZK_integrated.rds",
        # Stage 05: main (Fig. 3) and supplementary (Fig. S3) figure PDFs.
        [
            f"{OUTPUT}/05_figures/{fig}.pdf"
            for fig in (
                "fig3A_UMAP",
                "fig3B_dotplot",
                "fig3C_dittobarplot",
                "fig3D_Wnt7b_featureplot",
                "figS3A_marker_featureplots",
                "figS3B_wnt_family_violin",
            )
        ],
        # Stage 06: differential expression / accessibility tables.
        [
            f"{OUTPUT}/06_differential/{table}.csv"
            for table in (
                "DEG_MAST_results",
                "KRAS_signaling_overlap",
                "DA_peaks_Wnt7_vs_Lgr5_mutant",
                "DA_peaks_Wnt7_vs_Lgr5_wildtype",
                "Wnt7b_expression_stats",
                "Wnt7b_accessibility_stats",
            )
        ],
        # Stage 07: motif enrichment tables, then their bar plots.
        [f"{OUTPUT}/07_motifs/motif_enrichment_Wnt_vs_Lgr5_{geno}.csv" for geno in ("mutant", "wildtype")],
        [f"{OUTPUT}/07_motifs/fig3I_motif_barplot_{geno}.pdf" for geno in ("mutant", "wildtype")]

# Stage 01 — load and QC both genotypes (per the rule/script name). Writes one
# RDS object per genotype (wild type, mutant) plus a QC violin-plot PDF each.
# NOTE(review): this rule declares no `input:`, so Snakemake cannot see the raw
# data the script reads (presumably located via the config file) and will not
# rerun this stage when that data changes. Consider declaring those files.
rule load_and_qc:
    output:
        wild = f"{OUTPUT}/01_qc/wild_type_qc.rds",
        mutant = f"{OUTPUT}/01_qc/mutant_qc.rds",
        qc_violin_wild = f"{OUTPUT}/01_qc/qc_violin_wild_type.pdf",
        qc_violin_mutant = f"{OUTPUT}/01_qc/qc_violin_mutant.pdf"
    # Threads reserved for job scheduling only. {threads} is not passed to the
    # R script, which presumably takes its own thread count from the config.
    # TODO confirm; otherwise running with fewer cores may oversubscribe.
    threads: config.get("threads", 20)
    log:
        f"{OUTPUT}/logs/01_load_and_qc.log"
    # The script receives the absolute path /pipeline/config.yaml, not
    # necessarily the file loaded by `configfile:`. stdout and stderr are both
    # echoed to the console and written to the log via tee. An Rscript failure
    # surfaces only because Snakemake runs bash commands in strict mode
    # (pipefail).
    shell:
        "Rscript /pipeline/scripts/01_load_and_qc.R /pipeline/config.yaml 2>&1 | tee {log}"

# Stage 02 — peak calling (MACS2, per the rule/script name) on the per-genotype
# objects from load_and_qc. Only the two RDS objects are declared as inputs;
# the QC violin PDFs from that rule are not needed here.
rule macs2_peaks:
    input:
        wild = rules.load_and_qc.output.wild,
        mutant = rules.load_and_qc.output.mutant
    output:
        wild = f"{OUTPUT}/02_peaks/wild_type_peaks.rds",
        mutant = f"{OUTPUT}/02_peaks/mutant_peaks.rds"
    # Scheduling reservation only; not forwarded to the script (no {threads}).
    threads: config.get("threads", 20)
    log:
        f"{OUTPUT}/logs/02_macs2_peaks.log"
    # The script locates its inputs/outputs through /pipeline/config.yaml (an
    # absolute path, independent of `configfile:`). Output goes to console + log.
    shell:
        "Rscript /pipeline/scripts/02_macs2_peaks.R /pipeline/config.yaml 2>&1 | tee {log}"

# Stage 03 — normalisation of each genotype's peak-annotated object, producing
# one normalised RDS per genotype.
rule normalise:
    input:
        wild = rules.macs2_peaks.output.wild,
        mutant = rules.macs2_peaks.output.mutant
    output:
        wild = f"{OUTPUT}/03_normalised/wild_type_normalised.rds",
        mutant = f"{OUTPUT}/03_normalised/mutant_normalised.rds"
    # Scheduling reservation only; not forwarded to the script (no {threads}).
    threads: config.get("threads", 20)
    log:
        f"{OUTPUT}/logs/03_normalise.log"
    # Note the script file is spelled "03_normalize.R" (American spelling)
    # while the rule, log and output directory use "normalise". Do not align
    # one without renaming the other. Config path is absolute (see above rules).
    shell:
        "Rscript /pipeline/scripts/03_normalize.R /pipeline/config.yaml 2>&1 | tee {log}"

# Stage 04 — combine the wild-type and mutant normalised objects into a single
# integrated RZ/RZK object (clustering per the rule name). Every downstream
# stage (figures, differential analysis, motifs) reads this one file.
rule integrate_cluster:
    input:
        wild = rules.normalise.output.wild,
        mutant = rules.normalise.output.mutant
    output:
        combined = f"{OUTPUT}/04_integrated/RZ_RZK_integrated.rds"
    # Scheduling reservation only; not forwarded to the script (no {threads}).
    threads: config.get("threads", 20)
    log:
        f"{OUTPUT}/logs/04_integrate_cluster.log"
    # Script reads the absolute /pipeline/config.yaml; output to console + log.
    shell:
        "Rscript /pipeline/scripts/04_integrate_cluster.R /pipeline/config.yaml 2>&1 | tee {log}"

# Stage 05 — render the Fig. 3A-D and Fig. S3A-B PDFs from the integrated
# object. Runs independently of the differential/motif stages.
rule visualisation:
    input:
        obj = rules.integrate_cluster.output.combined
    output:
        umap = f"{OUTPUT}/05_figures/fig3A_UMAP.pdf",
        dotplot = f"{OUTPUT}/05_figures/fig3B_dotplot.pdf",
        ditto = f"{OUTPUT}/05_figures/fig3C_dittobarplot.pdf",
        wnt7b = f"{OUTPUT}/05_figures/fig3D_Wnt7b_featureplot.pdf",
        markers = f"{OUTPUT}/05_figures/figS3A_marker_featureplots.pdf",
        wnt_vln = f"{OUTPUT}/05_figures/figS3B_wnt_family_violin.pdf"
    # Scheduling reservation only; not forwarded to the script (no {threads}).
    threads: config.get("threads", 20)
    log:
        f"{OUTPUT}/logs/05_visualisation.log"
    # Script file is spelled "05_visualization.R" while the rule/log use
    # "visualisation". Config path is absolute; output to console + log.
    shell:
        "Rscript /pipeline/scripts/05_visualization.R /pipeline/config.yaml 2>&1 | tee {log}"

# Stage 06 — differential analyses on the integrated object. Produces these
# tables:
#   - DEG results (MAST, per the file name);
#   - a KRAS-signalling overlap table;
#   - differentially accessible peaks (Wnt7 vs Lgr5) for each genotype;
#   - Wnt7b expression and accessibility statistics.
# The two DA-peak tables (da_mut, da_wt) are inputs of motif_enrichment.
rule differential:
    input:
        obj = rules.integrate_cluster.output.combined
    output:
        deg = f"{OUTPUT}/06_differential/DEG_MAST_results.csv",
        kras = f"{OUTPUT}/06_differential/KRAS_signaling_overlap.csv",
        da_mut = f"{OUTPUT}/06_differential/DA_peaks_Wnt7_vs_Lgr5_mutant.csv",
        da_wt = f"{OUTPUT}/06_differential/DA_peaks_Wnt7_vs_Lgr5_wildtype.csv",
        wnt_expr = f"{OUTPUT}/06_differential/Wnt7b_expression_stats.csv",
        wnt_acc = f"{OUTPUT}/06_differential/Wnt7b_accessibility_stats.csv"
    # Scheduling reservation only; not forwarded to the script (no {threads}).
    threads: config.get("threads", 20)
    log:
        f"{OUTPUT}/logs/06_differential.log"
    # Script reads the absolute /pipeline/config.yaml; output to console + log.
    shell:
        "Rscript /pipeline/scripts/06_differential.R /pipeline/config.yaml 2>&1 | tee {log}"

# Stage 07 — motif enrichment on the Wnt7-vs-Lgr5 differentially accessible
# peaks for each genotype. Writes an enrichment table and a bar plot (Fig. 3I)
# per genotype. The integrated object and both DA-peak tables from the
# differential rule are declared as inputs, so this stage waits for them.
rule motif_enrichment:
    input:
        obj = rules.integrate_cluster.output.combined,
        da_mut = rules.differential.output.da_mut,
        da_wt = rules.differential.output.da_wt
    output:
        motif_mut = f"{OUTPUT}/07_motifs/motif_enrichment_Wnt_vs_Lgr5_mutant.csv",
        motif_wt = f"{OUTPUT}/07_motifs/motif_enrichment_Wnt_vs_Lgr5_wildtype.csv",
        plot_mut = f"{OUTPUT}/07_motifs/fig3I_motif_barplot_mutant.pdf",
        plot_wt = f"{OUTPUT}/07_motifs/fig3I_motif_barplot_wildtype.pdf"
    # Unlike the other rules, this reserves a single thread rather than
    # config["threads"]. NOTE(review): presumably deliberate (e.g. the motif
    # step is not parallel) — confirm.
    threads: 1
    log:
        f"{OUTPUT}/logs/07_motif_enrichment.log"
    # Script reads the absolute /pipeline/config.yaml; output to console + log.
    shell:
        "Rscript /pipeline/scripts/07_motif_enrichment.R /pipeline/config.yaml 2>&1 | tee {log}"